{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "8KcfZ0oRSJ-I"
      },
      "source": [
        "**Copyright 2019 The Sonnet Authors. All Rights Reserved.**\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "   http://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License.\n",
        "\n",
        "---"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "PUWojJ3iTgKw"
      },
      "source": [
        "# Introduction\n",
        "\n",
        "This tutorial assumes you have already completed (and understood!) the Sonnet 2 \"Hello, world!\" example (MLP on MNIST).\n",
        "\n",
        "In this tutorial, we're going to scale things up with a bigger model and bigger dataset, and we're going to distribute the computation across multiple devices."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "y4TOXSlnTcSB"
      },
      "source": [
        "# Preamble"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "9f811iaTTbmI"
      },
      "outputs": [],
      "source": [
        "import sys\n",
        "assert sys.version_info \u003e= (3, 6), \"Sonnet 2 requires Python \u003e=3.6\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "5R3-sAFoTiyB"
      },
      "outputs": [],
      "source": [
        "!pip install dm-sonnet tqdm"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "RdSVvLnkTmWY"
      },
      "outputs": [],
      "source": [
        "import sonnet as snt\n",
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "uPtNKJwhTnze"
      },
      "outputs": [],
      "source": [
        "print(\"TensorFlow version: {}\".format(tf.__version__))\n",
        "print(\"    Sonnet version: {}\".format(snt.__version__))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Ab-nCz85sIkh"
      },
      "source": [
        "Finally lets take a quick look at the GPUs we have available:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "elmPgnWJsIUI"
      },
      "outputs": [],
      "source": [
        "!grep Model: /proc/driver/nvidia/gpus/*/information | awk '{$1=\"\";print$0}'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "-T4gfAHytWrk"
      },
      "source": [
        "# Distribution strategy\n",
        "\n",
        "We need a strategy to distribute our computation across several devices. Since Google Colab only provides a single GPU we'll split it into four virtual GPUs:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "colab_type": "code",
        "id": "14VQpLastTlW",
        "outputId": "f8d1bcf2-ca45-43bb-8485-5e6589f0b2e3"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]"
            ]
          },
          "execution_count": 8,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "physical_gpus = tf.config.experimental.list_physical_devices(\"GPU\")\n",
        "physical_gpus"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "aoEY-15BtYaf"
      },
      "outputs": [],
      "source": [
        "tf.config.experimental.set_virtual_device_configuration(\n",
        "    physical_gpus[0],\n",
        "    [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=2000)] * 4\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 85
        },
        "colab_type": "code",
        "id": "risJlPoDtx6-",
        "outputId": "522e0953-de6f-47f8-f06d-5dc14ad32465"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[LogicalDevice(name='/job:localhost/replica:0/task:0/device:GPU:0', device_type='GPU'),\n",
              " LogicalDevice(name='/job:localhost/replica:0/task:0/device:GPU:1', device_type='GPU'),\n",
              " LogicalDevice(name='/job:localhost/replica:0/task:0/device:GPU:2', device_type='GPU'),\n",
              " LogicalDevice(name='/job:localhost/replica:0/task:0/device:GPU:3', device_type='GPU')]"
            ]
          },
          "execution_count": 12,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "gpus = tf.config.experimental.list_logical_devices(\"GPU\")\n",
        "gpus"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "WhdyolCktTGU"
      },
      "source": [
        "When using Sonnet optimizers, we must use either `Replicator` or `TpuReplicator` from `snt.distribute`, or we can use `tf.distribute.OneDeviceStrategy`. `Replicator` is equivalent to `MirroredStrategy` and `TpuReplicator` is equivalent to `TPUStrategy`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "both",
        "colab": {},
        "colab_type": "code",
        "id": "J82G9zqxtb1c"
      },
      "outputs": [],
      "source": [
        "strategy = snt.distribute.Replicator(\n",
        "    [\"/device:GPU:{}\".format(i) for i in range(4)],\n",
        "    tf.distribute.ReductionToOneDevice(\"GPU:0\"))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "smiRXbgmT9SD"
      },
      "source": [
        "# Dataset\n",
        "\n",
        "Basically the same as the MNIST example, but this time we're using CIFAR-10. CIFAR-10 contains 32x32 pixel color images in 10 different classes (airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "1xOwe9y_T_A4"
      },
      "outputs": [],
      "source": [
        "# NOTE: This is the batch size across all GPUs.\n",
        "batch_size = 100 * 4\n",
        "\n",
        "def process_batch(images, labels):\n",
        "  images = tf.cast(images, dtype=tf.float32)\n",
        "  images = ((images / 255.) - .5) * 2.\n",
        "  return images, labels\n",
        "\n",
        "def cifar10(split):\n",
        "  dataset = tfds.load(\"cifar10\", split=split, as_supervised=True)\n",
        "  dataset = dataset.map(process_batch)\n",
        "  dataset = dataset.batch(batch_size)\n",
        "  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)\n",
        "  dataset = dataset.cache()\n",
        "  return dataset\n",
        "\n",
        "cifar10_train = cifar10(\"train\").shuffle(10)\n",
        "cifar10_test = cifar10(\"test\")"
      ]
    },
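    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The arithmetic in the cell above can be checked in isolation. The following is a minimal sketch in plain Python, with no TensorFlow involved: `scale_pixel` is a hypothetical helper that mirrors the expression in `process_batch`, and the 50,000-image training split is the standard CIFAR-10 size, stated here as an assumption rather than read from the dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "# Illustrative sketch only: `scale_pixel` and the constants below are\n",
        "# hypothetical names, not part of Sonnet or TensorFlow.\n",
        "def scale_pixel(value):\n",
        "  \"\"\"Maps a pixel value in [0, 255] to [-1, 1], as `process_batch` does.\"\"\"\n",
        "  return ((value / 255.) - .5) * 2.\n",
        "\n",
        "num_replicas = 4\n",
        "global_batch_size = 100 * num_replicas\n",
        "per_replica_batch_size = global_batch_size // num_replicas\n",
        "\n",
        "# Assumption: the standard CIFAR-10 training split has 50,000 images.\n",
        "num_train_examples = 50000\n",
        "steps_per_epoch = num_train_examples // global_batch_size\n",
        "\n",
        "print(\"pixel 0 -\u003e\", scale_pixel(0))\n",
        "print(\"pixel 255 -\u003e\", scale_pixel(255))\n",
        "print(\"per-replica batch size:\", per_replica_batch_size)\n",
        "print(\"steps per epoch:\", steps_per_epoch)"
      ]
    },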
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "BT2fP7jjpUGe"
      },
      "source": [
        "# Model & Optimizer\n",
        "\n",
        "Conveniently, there is a pre-built model in `snt.nets` designed specifically for this dataset.\n",
        "\n",
        "We must build our model and optimizer within the strategy scope, to ensure that any variables created are distributed correctly. Alternatively, we could enter the scope for the entire program using `tf.distribute.experimental_set_strategy`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "vEk6eJUPpWB-"
      },
      "outputs": [],
      "source": [
        "learning_rate = 0.1\n",
        "\n",
        "with strategy.scope():\n",
        "  model = snt.nets.Cifar10ConvNet()\n",
        "  optimizer = snt.optimizers.Momentum(learning_rate, 0.9)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "YMgcImzms2V0"
      },
      "source": [
        "# Training the model\n",
        "\n",
        "The Sonnet optimizers are designed to be as clean and simple as possible. They do not contain any code to deal with distributed execution. It therefore requires a few additional lines of code.\n",
        "\n",
        "We must aggregate the gradients calculated on the different devices. This can be done using `ReplicaContext.all_reduce`.\n",
        "\n",
        "Note that when using `Replicator` / `TpuReplicator` it is the user's responsibility to ensure that the values remain identical in all replicas."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "RUkuAJsxsjvt"
      },
      "outputs": [],
      "source": [
        "def step(images, labels):\n",
        "  \"\"\"Performs a single training step, returning the cross-entropy loss.\"\"\"\n",
        "  with tf.GradientTape() as tape:\n",
        "    logits = model(images, is_training=True)[\"logits\"]\n",
        "    loss = tf.reduce_mean(\n",
        "        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,\n",
        "                                                       logits=logits))\n",
        "\n",
        "  grads = tape.gradient(loss, model.trainable_variables)\n",
        "\n",
        "  # Aggregate the gradients from the full batch.\n",
        "  replica_ctx = tf.distribute.get_replica_context()\n",
        "  grads = replica_ctx.all_reduce(\"mean\", grads)\n",
        "\n",
        "  optimizer.apply(grads, model.trainable_variables)\n",
        "  return loss\n",
        "\n",
        "@tf.function\n",
        "def train_step(images, labels):\n",
        "  per_replica_loss = strategy.run(step, args=(images, labels))\n",
        "  return strategy.reduce(\"sum\", per_replica_loss, axis=None)\n",
        "\n",
        "def train_epoch(dataset):\n",
        "  \"\"\"Performs one epoch of training, returning the mean cross-entropy loss.\"\"\"\n",
        "  total_loss = 0.0\n",
        "  num_batches = 0\n",
        "\n",
        "  # Loop over the entire training set.\n",
        "  for images, labels in dataset:\n",
        "    total_loss += train_step(images, labels).numpy()\n",
        "    num_batches += 1\n",
        "\n",
        "  return total_loss / num_batches\n",
        "\n",
        "cifar10_train_dist = strategy.experimental_distribute_dataset(cifar10_train)\n",
        "\n",
        "for epoch in range(20):\n",
        "  print(\"Training epoch\", epoch, \"...\", end=\" \")\n",
        "  print(\"loss :=\", train_epoch(cifar10_train_dist))"
      ]
    },
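    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The gradient aggregation in `step` can be pictured without any devices. The following is a minimal sketch in plain Python, not a description of how `ReplicaContext.all_reduce` is implemented: `all_reduce_mean` is a hypothetical helper and the gradient values are made up for illustration. It shows that after a mean all-reduce every replica holds the same gradients, so identical variables receive identical updates."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "# Illustrative sketch only: `all_reduce_mean` is a hypothetical helper and\n",
        "# the numbers are made up. The real all-reduce happens inside TensorFlow.\n",
        "def all_reduce_mean(per_replica_grads):\n",
        "  \"\"\"Averages each gradient across replicas and gives every replica a copy.\"\"\"\n",
        "  num_replicas = len(per_replica_grads)\n",
        "  # zip(*...) groups the i-th gradient from every replica together.\n",
        "  mean_grads = [sum(g) / num_replicas for g in zip(*per_replica_grads)]\n",
        "  return [list(mean_grads) for _ in range(num_replicas)]\n",
        "\n",
        "# Four replicas, each holding gradients for two scalar variables.\n",
        "per_replica_grads = [[1.0, 8.0], [2.0, 6.0], [3.0, 4.0], [6.0, 2.0]]\n",
        "reduced = all_reduce_mean(per_replica_grads)\n",
        "\n",
        "# Every replica now sees the same averaged gradients.\n",
        "print(reduced[0])\n",
        "print(all(r == reduced[0] for r in reduced))"
      ]
    },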
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "cHbiCxJ81wZo"
      },
      "source": [
        "# Evaluating the model\n",
        "\n",
        "Note the use of the `axis` parameter with `strategy.reduce` to reduce across the batch dimension."
      ]
    },
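    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "As a rough picture of what `strategy.reduce(\"sum\", ..., axis=0)` computes, the following is a minimal sketch in plain Python. `reduce_sum_over_batch` is a hypothetical helper and the 0/1 values are made up; each inner list stands for one replica's per-example results."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "# Illustrative sketch only: `reduce_sum_over_batch` is a hypothetical helper.\n",
        "def reduce_sum_over_batch(per_replica_values):\n",
        "  \"\"\"Sums over the batch dimension of each replica, then across replicas.\"\"\"\n",
        "  return sum(sum(values) for values in per_replica_values)\n",
        "\n",
        "# Made-up results: 1 marks a correct prediction, 0 an incorrect one.\n",
        "per_replica_correct = [[1, 0, 1], [1, 1, 1], [0, 0, 1], [1, 0, 0]]\n",
        "total_correct = reduce_sum_over_batch(per_replica_correct)\n",
        "num_examples = sum(len(values) for values in per_replica_correct)\n",
        "\n",
        "print(\"correct:\", total_correct, \"of\", num_examples)\n",
        "print(\"accuracy:\", total_correct / num_examples)"
      ]
    },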
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "both",
        "colab": {},
        "colab_type": "code",
        "id": "hxL8-GOB1yAq"
      },
      "outputs": [],
      "source": [
        "num_cifar10_test_examples = 10000\n",
        "\n",
        "def is_predicted(images, labels):\n",
        "  logits = model(images, is_training=False)[\"logits\"]\n",
        "  # The reduction over the batch happens in `strategy.reduce`, below.\n",
        "  return tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.int32)\n",
        "\n",
        "cifar10_test_dist = strategy.experimental_distribute_dataset(cifar10_test)\n",
        "\n",
        "@tf.function\n",
        "def evaluate():\n",
        "  \"\"\"Returns the top-1 accuracy over the entire test set.\"\"\"\n",
        "  total_correct = 0\n",
        "\n",
        "  for images, labels in cifar10_test_dist:\n",
        "    per_replica_correct = strategy.run(is_predicted, args=(images, labels))\n",
        "    total_correct += strategy.reduce(\"sum\", per_replica_correct, axis=0)\n",
        "\n",
        "  return tf.cast(total_correct, tf.float32) / num_cifar10_test_examples\n",
        "\n",
        "print(\"Testing...\", end=\" \")\n",
        "print(\"top-1 accuracy =\", evaluate().numpy())"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "Multi-GPU training with Sonnet 2",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
